#!/usr/bin/python
# encoding: utf-8
import os
from torch.utils.data import Dataset
from PIL import Image
import torch


class GTResDataset(Dataset):
    """Dataset of (image, ground-truth image) pairs matched by file name.

    Every file directly inside ``root_path`` whose name ends with ``.jpg`` or
    ``.png`` (case-sensitive) is paired with a file in ``gt_dir``. The
    ground-truth file has the same name, except that a trailing ``.png`` is
    swapped for ``.jpg``. Whether the ground-truth file exists is not checked
    here; a missing file surfaces when the item is loaded.

    Attributes:
        pairs: list of ``[image_path, gt_path, None]`` entries, sorted by
            image file name so that indices are reproducible. The third slot
            is always ``None`` and is not read by this class.
        transform: optional callable applied to both images of a pair.
        transform_train: stored as given; not used by this class.
    """

    # File-name suffixes (case-sensitive) recognised as input images.
    _IMAGE_SUFFIXES = (".jpg", ".png")

    def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None):
        """Collect the image / ground-truth path pairs.

        Args:
            root_path: directory containing the input images.
            gt_dir: directory containing the ground-truth ``.jpg`` images.
                Must be given whenever ``root_path`` contains any image.
            transform: optional callable applied to each loaded image.
            transform_train: optional object kept on the instance unchanged.

        Raises:
            TypeError: if ``root_path`` contains images but ``gt_dir`` is None.
        """
        super().__init__()

        # os.listdir returns entries in arbitrary order; sort so that a given
        # index refers to the same sample on every run and every machine.
        image_names = sorted(
            name for name in os.listdir(root_path)
            if name.endswith(self._IMAGE_SUFFIXES)
        )
        if image_names and gt_dir is None:
            raise TypeError(
                "gt_dir must be a directory path, not None, "
                "when root_path contains images"
            )

        self.pairs = [
            [
                os.path.join(root_path, name),
                os.path.join(gt_dir, self._gt_name(name)),
                None,
            ]
            for name in image_names
        ]
        self.transform = transform
        self.transform_train = transform_train

    @staticmethod
    def _gt_name(image_name):
        """Return the ground-truth file name that corresponds to ``image_name``."""
        # Swap only a trailing ".png"; a ".png" occurring elsewhere in the
        # name (or in the directory path) must be left untouched.
        if image_name.endswith(".png"):
            return image_name[:-len(".png")] + ".jpg"
        return image_name

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return it converted to RGB."""
        # convert() returns a separate, fully loaded image, so the underlying
        # file can be released as soon as the ``with`` block exits.
        with Image.open(path) as image:
            return image.convert('RGB')

    def __len__(self):
        """Return the number of image / ground-truth pairs."""
        return len(self.pairs)

    def __getitem__(self, index):
        """Load the pair at ``index``.

        Returns:
            A ``(from_im, to_im)`` tuple: the input image and its ground
            truth, both converted to RGB and, if ``transform`` is set, each
            passed through it.
        """
        from_path, to_path, _ = self.pairs[index]
        from_im = self._load_rgb(from_path)
        to_im = self._load_rgb(to_path)

        if self.transform:
            # Ground truth is transformed first, then the input; the order is
            # kept fixed in case the transform draws from a random generator.
            to_im = self.transform(to_im)
            from_im = self.transform(from_im)

        return from_im, to_im
